{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/v38/translation/en-ms/base-translation.pb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.tools.graph_transforms import TransformGraph\n",
    "from glob import glob\n",
    "tf.set_random_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_text\n",
    "import tf_sentencepiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['base-translation.pb']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pbs = glob('*.pb')\n",
    "pbs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-6-a18cfd735cdf>:11: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.gfile.GFile.\n",
      "base-translation.pb\n"
     ]
    }
   ],
   "source": [
    "# Graph Transform Tool passes applied, in this order, to every frozen graph in `pbs`.\n",
    "transforms = ['add_default_attributes',\n",
    "             'remove_nodes(op=Identity, op=CheckNumerics)',\n",
    "             'fold_constants(ignore_errors=true)',\n",
    "             'fold_batch_norms',\n",
    "             'fold_old_batch_norms',\n",
    "             'strip_unused_nodes',\n",
    "             'sort_by_execution_order']\n",
    "\n",
    "for pb in pbs:\n",
    "    input_graph_def = tf.GraphDef()\n",
    "    # tf.gfile.FastGFile is deprecated; tf.gfile.GFile is the documented replacement.\n",
    "    with tf.gfile.GFile(pb, 'rb') as f:\n",
    "        input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "    print(pb)\n",
    "\n",
    "    # 'Placeholder' is passed as the graph input, 'greedy' and 'beam' as the outputs.\n",
    "    transformed_graph_def = TransformGraph(input_graph_def,\n",
    "                                           ['Placeholder'],\n",
    "                                           ['greedy', 'beam'], transforms)\n",
    "\n",
    "    # The result is written next to the source file with an '.optimized' suffix.\n",
    "    with tf.gfile.GFile(f'{pb}.optimized', 'wb') as f:\n",
    "        f.write(transformed_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2\n",
    "from tensorflow.tools.graph_transforms import TransformGraph\n",
    "from google.protobuf import text_format\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph_with_sess(frozen_graph_filename):\n",
    "    \"\"\"Read a serialized GraphDef from disk and import it into a new tf.Graph.\n",
    "\n",
    "    No name is passed to tf.import_graph_def, so the imported nodes receive\n",
    "    TensorFlow's default name prefix.\n",
    "\n",
    "    Args:\n",
    "        frozen_graph_filename: path of the serialized GraphDef file.\n",
    "\n",
    "    Returns:\n",
    "        The tf.Graph that the GraphDef was imported into.\n",
    "    \"\"\"\n",
    "    serialized = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as reader:\n",
    "        serialized.ParseFromString(reader.read())\n",
    "\n",
    "    imported = tf.Graph()\n",
    "    with imported.as_default():\n",
    "        tf.import_graph_def(serialized)\n",
    "    return imported\n",
    "\n",
    "def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):\n",
    "    \"\"\"Append a FusedBatchNormV2 copy of `node` to `graph_def`.\n",
    "\n",
    "    The copy reuses the name and inputs of `node`. Its `U` attribute is set to\n",
    "    float32 and its `T` attribute to the target dtype; the original note gives\n",
    "    the reason as reserve_space_1 / reserve_space_2 of FusedBatchNorm requiring\n",
    "    float32 for gradient calculation, see\n",
    "    https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm\n",
    "\n",
    "    Side effect: the `T` attribute of the input `node` is changed as well.\n",
    "\n",
    "    Args:\n",
    "        node: the FusedBatchNorm NodeDef to copy.\n",
    "        graph_def: GraphDef that receives the new node.\n",
    "        target_type: 'fp16' (DT_HALF) or 'fp64' (DT_DOUBLE); anything else\n",
    "            selects DT_FLOAT.\n",
    "    \"\"\"\n",
    "    dtype = types_pb2.DT_FLOAT\n",
    "    if target_type == 'fp16':\n",
    "        dtype = types_pb2.DT_HALF\n",
    "    elif target_type == 'fp64':\n",
    "        dtype = types_pb2.DT_DOUBLE\n",
    "\n",
    "    replacement = graph_def.node.add()\n",
    "    replacement.op = 'FusedBatchNormV2'\n",
    "    replacement.name = node.name\n",
    "    replacement.input.extend(node.input)\n",
    "    float32_attr = attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT)\n",
    "    replacement.attr['U'].CopyFrom(float32_attr)\n",
    "\n",
    "    # Retype T on the source node first, then mirror every attribute onto the copy.\n",
    "    for key in list(node.attr.keys()):\n",
    "        if key == 'T':\n",
    "            node.attr[key].type = dtype\n",
    "        replacement.attr[key].CopyFrom(node.attr[key])\n",
    "    print('rewrite fused_batch_norm done!')\n",
    "\n",
    "def convert_graph_to_fp16(model_path, as_text=False, target_type='fp16',\n",
    "                          input_name=None, output_names=None):\n",
    "    \"\"\"Convert the float32 parts of a frozen graph to `target_type` and save it.\n",
    "\n",
    "    The graph at `model_path` is loaded with `load_graph_with_sess`, edited node\n",
    "    by node, and written to f'{model_path}.optimized'.\n",
    "\n",
    "    Per node:\n",
    "      * FusedBatchNorm becomes FusedBatchNormV2 (T = target dtype, U = float32).\n",
    "      * Nodes whose name contains 'BatchNorm' or 'batch_normalization' are skipped.\n",
    "      * 'convert_gradient_to_tensor_HBc3xYw22Mw' becomes Identity.\n",
    "      * Nodes named in the notebook-level list `keep_fp32_node_name` keep their\n",
    "        attributes as they are; that list must be defined before calling.\n",
    "      * Every remaining DT_FLOAT type attribute and float32 `value` tensor is\n",
    "        converted to the target dtype.\n",
    "\n",
    "    Args:\n",
    "        model_path: path of the serialized GraphDef to convert.\n",
    "        as_text: accepted for call compatibility; not used.\n",
    "        target_type: 'fp16' (DT_HALF) or 'fp64' (DT_DOUBLE); anything else\n",
    "            selects DT_FLOAT.\n",
    "        input_name: accepted for call compatibility; not used.\n",
    "        output_names: accepted for call compatibility; not used.\n",
    "\n",
    "    Returns:\n",
    "        None. The converted graph is only written to disk.\n",
    "    \"\"\"\n",
    "    if target_type == 'fp16':\n",
    "        dtype = types_pb2.DT_HALF\n",
    "    elif target_type == 'fp64':\n",
    "        dtype = types_pb2.DT_DOUBLE\n",
    "    else:\n",
    "        dtype = types_pb2.DT_FLOAT\n",
    "\n",
    "    graph = load_graph_with_sess(model_path)\n",
    "    source_graph_def = graph.as_graph_def()\n",
    "\n",
    "    for node in source_graph_def.node:\n",
    "        if node.op == 'FusedBatchNorm':\n",
    "            # Done in place: appending a replacement node to the GraphDef being\n",
    "            # iterated would leave two nodes with the same name.\n",
    "            node.op = 'FusedBatchNormV2'\n",
    "            node.attr['T'].type = dtype\n",
    "            node.attr['U'].type = types_pb2.DT_FLOAT\n",
    "            continue\n",
    "        # keep batch norm params node\n",
    "        if ('BatchNorm' in node.name) or ('batch_normalization' in node.name):\n",
    "            continue\n",
    "\n",
    "        if node.op == 'convert_gradient_to_tensor_HBc3xYw22Mw':\n",
    "            node.op = 'Identity'\n",
    "            # Follow the requested dtype instead of always forcing DT_HALF.\n",
    "            node.attr['T'].type = dtype\n",
    "            if '_disable_call_shape_inference' in node.attr:\n",
    "                del node.attr['_disable_call_shape_inference']\n",
    "\n",
    "        # keep special node in fp32\n",
    "        if node.name in keep_fp32_node_name:\n",
    "            continue\n",
    "\n",
    "        # Snapshot the keys only after the rewrite above: indexing a protobuf map\n",
    "        # with an already deleted key would silently re-create an empty entry.\n",
    "        for attr in list(node.attr.keys()):\n",
    "            attr_value = node.attr[attr]\n",
    "            if attr_value.type == types_pb2.DT_FLOAT:\n",
    "                # modify node dtype\n",
    "                attr_value.type = dtype\n",
    "            if attr != 'value':\n",
    "                continue\n",
    "            tensor = attr_value.tensor\n",
    "            if tensor.dtype != types_pb2.DT_FLOAT:\n",
    "                continue\n",
    "            # tf.make_ndarray handles both float_val and tensor_content and already\n",
    "            # returns the array in the tensor's declared shape.\n",
    "            if tensor.float_val or tensor.tensor_content:\n",
    "                values = tf.make_ndarray(tensor)\n",
    "                tensor.CopyFrom(tf.make_tensor_proto(values, dtype=dtype))\n",
    "\n",
    "    with tf.gfile.GFile(f'{model_path}.optimized', 'wb') as f:\n",
    "        f.write(source_graph_def.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to convert_graph_to_fp16, which accepts but does not use these three.\n",
    "input_name = ['Placeholder']\n",
    "output_names = ['greedy', 'beam']\n",
    "as_text = False\n",
    "\n",
    "# Read as a notebook-level global inside convert_graph_to_fp16: the attributes of\n",
    "# nodes named here are left unconverted.\n",
    "keep_fp32_node_name = []\n",
    "\n",
    "model_path = 'base-translation.pb.optimized'\n",
    "target_type = 'fp16'\n",
    "\n",
    "# The function returns None and only writes f'{model_path}.optimized', so its\n",
    "# result is not bound to a name.\n",
    "convert_graph_to_fp16(model_path,\n",
    "                      as_text=as_text,\n",
    "                      target_type=target_type,\n",
    "                      input_name=input_name, output_names=output_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename, **kwargs):\n",
    "    \"\"\"Load a frozen GraphDef from disk into a new tf.Graph.\n",
    "\n",
    "    Before the import some ops are rewritten: RefSwitch to Switch, AssignSub to\n",
    "    Sub, AssignAdd to Add and Assign to Identity. Background:\n",
    "    https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091\n",
    "\n",
    "    Args:\n",
    "        frozen_graph_filename: path of the serialized GraphDef file.\n",
    "        **kwargs: accepted and ignored.\n",
    "\n",
    "    Returns:\n",
    "        The tf.Graph that the GraphDef was imported into. No name is passed to\n",
    "        tf.import_graph_def, so node names carry its default prefix.\n",
    "    \"\"\"\n",
    "    with tf.io.gfile.GFile(frozen_graph_filename, 'rb') as f:\n",
    "        graph_def = tf.compat.v1.GraphDef()\n",
    "        graph_def.ParseFromString(f.read())\n",
    "\n",
    "    # to fix import T5\n",
    "    for node in graph_def.node:\n",
    "        if node.op == 'RefSwitch':\n",
    "            node.op = 'Switch'\n",
    "            # range, not xrange: the notebook runs on a Python 3 kernel where\n",
    "            # xrange is undefined and would raise NameError here.\n",
    "            for index in range(len(node.input)):\n",
    "                if 'moving_' in node.input[index]:\n",
    "                    node.input[index] = node.input[index] + '/read'\n",
    "        elif node.op in ('AssignSub', 'AssignAdd'):\n",
    "            node.op = 'Sub' if node.op == 'AssignSub' else 'Add'\n",
    "            if 'use_locking' in node.attr:\n",
    "                del node.attr['use_locking']\n",
    "        elif node.op == 'Assign':\n",
    "            node.op = 'Identity'\n",
    "            if 'use_locking' in node.attr:\n",
    "                del node.attr['use_locking']\n",
    "            if 'validate_shape' in node.attr:\n",
    "                del node.attr['validate_shape']\n",
    "            # Identity takes one input: keep the assigned value, drop the target.\n",
    "            if len(node.input) == 2:\n",
    "                node.input[0] = node.input[1]\n",
    "                del node.input[1]\n",
    "\n",
    "    with tf.Graph().as_default() as graph:\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph\n",
    "\n",
    "# g_optimized = load_graph('base-translation.pb.optimized.fp16/base-translation.pb.optimized.fp16')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'base-translation.pb.optimized.optimized' is the file convert_graph_to_fp16 writes\n",
    "# for model_path='base-translation.pb.optimized' (it appends '.optimized').\n",
    "g_optimized = load_graph('base-translation.pb.optimized.optimized')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from malaya.text.t2t import text_encoder\n",
    "\n",
    "encoder = text_encoder.SubwordTextEncoder('/home/husein/Malaya/translation-en-ms/base/en-ms.subwords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = load_graph('base-translation.pb')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two 'import/' prefixes: the converted graph was imported once before it was saved\n",
    "# (load_graph_with_sess) and once more by load_graph, both times without a name.\n",
    "x_optimized = g_optimized.get_tensor_by_name('import/import/Placeholder:0')\n",
    "greedy_optimized = g_optimized.get_tensor_by_name('import/import/greedy:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single 'import/' prefix: this graph was imported once, by load_graph.\n",
    "x = g.get_tensor_by_name('import/Placeholder:0')\n",
    "greedy = g.get_tensor_by_name('import/greedy:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One session per graph, so the original and the converted model can each be run.\n",
    "sess = tf.InteractiveSession(graph = g)\n",
    "sess_optimized = tf.InteractiveSession(graph = g_optimized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "e = encoder.encode('Palestine, recognized officially as the State of Palestine by the United Nations and other entities, is a de jure sovereign state in Western Asia claiming the West Bank and Gaza Strip with Jerusalem as the designated capital, although its administrative center is currently located in Ramallah') + [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:23<00:00,  2.34s/it]\n"
     ]
    }
   ],
   "source": [
    "# Latency of the original graph: seconds per greedy run, collected in r.\n",
    "r = []\n",
    "\n",
    "for _ in tqdm(range(10)):\n",
    "    # perf_counter is the clock intended for measuring intervals; time.time()\n",
    "    # follows the system clock and can jump when that clock is adjusted.\n",
    "    before = time.perf_counter()\n",
    "    sess.run(greedy, feed_dict = {x: [e]})\n",
    "    r.append(time.perf_counter() - before)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.3338852405548094"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.mean(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:51<00:00,  5.16s/it]\n"
     ]
    }
   ],
   "source": [
    "# Latency of the converted graph: seconds per greedy run, collected in r.\n",
    "r = []\n",
    "\n",
    "for _ in tqdm(range(10)):\n",
    "    # perf_counter is the clock intended for measuring intervals; time.time()\n",
    "    # follows the system clock and can jump when that clock is adjusted.\n",
    "    before = time.perf_counter()\n",
    "    sess_optimized.run(greedy_optimized, feed_dict = {x_optimized: [e]})\n",
    "    r.append(time.perf_counter() - before)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.159882307052612"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(r)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
